{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Segmenting remote sensing imagery with point prompts\n",
    "\n",
    "[![image](https://studiolab.sagemaker.aws/studiolab.svg)](https://studiolab.sagemaker.aws/import/github/opengeos/segment-geospatial/blob/main/docs/examples/sam2_point_prompts.ipynb)\n",
    "[![image](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/opengeos/segment-geospatial/blob/main/docs/examples/sam2_point_prompts.ipynb)\n",
    "\n",
    "This notebook shows how to generate object masks from point prompts with the Segment Anything Model 2 (SAM 2).\n",
    "\n",
    "Make sure you use a GPU runtime for this notebook. In Google Colab, go to `Runtime` -> `Change runtime type` and select `GPU` as the hardware accelerator."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Install dependencies\n",
    "\n",
    "Uncomment and run the following cell to install the required dependencies."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %pip install -U segment-geospatial"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Import libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import leafmap\n",
    "from samgeo import SamGeo2\n",
    "from samgeo.common import regularize"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create an interactive map"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "m = leafmap.Map(center=[47.653287, -117.588070], zoom=16, height=\"800px\")\n",
    "m.add_basemap(\"Satellite\")\n",
    "m"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Download a sample image\n",
    "\n",
    "Pan and zoom the map to select the area of interest. Use the draw tools to draw a polygon or rectangle on the map. If no geometry is drawn, the default bounding box will be used."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if m.user_roi is not None:\n",
    "    bbox = m.user_roi_bounds()\n",
    "else:\n",
    "    # Default bounding box as [west, south, east, north] in longitude/latitude degrees\n",
    "    bbox = [-117.6029, 47.65, -117.5936, 47.6563]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "image = \"satellite.tif\"\n",
    "leafmap.map_tiles_to_geotiff(\n",
    "    output=image, bbox=bbox, zoom=18, source=\"Satellite\", overwrite=True\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To use your own image instead, uncomment and run the following cell."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# image = '/path/to/your/own/image.tif'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Display the downloaded image on the map."
   ]
  },
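  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hide the most recently added layer (the satellite basemap) so the downloaded image is visible\n",
    "m.layers[-1].visible = False\n",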
    "m.add_raster(image, layer_name=\"Image\")\n",
    "m"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initialize the SAM 2 class\n",
    "\n",
    "Set `automatic=False` to use the `SAM2ImagePredictor`, which segments from prompts such as points, instead of the automatic mask generator."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sam = SamGeo2(\n",
    "    model_id=\"sam2-hiera-large\",\n",
    "    automatic=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Specify the image to segment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sam.set_image(image)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Segment the image\n",
    "\n",
    "Use the `predict_by_points()` method to segment the image at the specified point coordinates. You can use the draw tools to place markers on the map. If no points are added, the default sample points will be used."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if m.user_rois is not None:\n",
    "    point_coords_batch = m.user_rois\n",
    "else:\n",
    "    # Default sample points as [longitude, latitude] pairs (EPSG:4326)\n",
    "    point_coords_batch = [\n",
    "        [-117.599896, 47.655345],\n",
    "        [-117.59992, 47.655167],\n",
    "        [-117.599928, 47.654974],\n",
    "        [-117.599518, 47.655337],\n",
    "    ]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Segment the objects using the point prompts and save the output masks."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sam.predict_by_points(\n",
    "    point_coords_batch=point_coords_batch,\n",
    "    point_crs=\"EPSG:4326\",\n",
    "    output=\"mask.tif\",\n",
    "    dtype=\"uint8\",\n",
    ")"
   ]
  },
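  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before adding the mask to the map, it can help to check what was written to disk. The following cell is a minimal sketch and not part of the original workflow: it assumes `rasterio` and `numpy` are available in the environment (they are normally installed alongside `segment-geospatial`) and that `mask.tif` was written by the previous cell. It prints the raster's CRS and shape and counts the pixels for each mask value."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import numpy as np\n",
    "import rasterio\n",
    "\n",
    "mask_path = \"mask.tif\"  # file written by predict_by_points() above\n",
    "\n",
    "if os.path.exists(mask_path):\n",
    "    with rasterio.open(mask_path) as src:\n",
    "        mask = src.read(1)  # first band as a 2D NumPy array\n",
    "        print(\"CRS:\", src.crs)\n",
    "        print(\"Shape (rows, cols):\", mask.shape)\n",
    "    # Count pixels per mask value (0 is expected to be background)\n",
    "    values, counts = np.unique(mask, return_counts=True)\n",
    "    for value, count in zip(values, counts):\n",
    "        print(f\"value {value}: {count} pixels\")\n",
    "else:\n",
    "    print(f\"{mask_path} not found; run the segmentation cell first\")"
   ]
  },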
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Display the result\n",
    "\n",
    "Add the segmented image to the map."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "m.add_raster(\"mask.tif\", cmap=\"viridis\", nodata=0, opacity=0.7, layer_name=\"Mask\")\n",
    "m"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![image](https://github.com/user-attachments/assets/49e413b9-e159-4d72-bf23-a0318bc82d44)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Use an existing vector dataset as point prompts\n",
    "\n",
    "Alternatively, you can specify a file path or HTTP URL to a vector dataset containing point geometries."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "geojson = \"https://github.com/opengeos/datasets/releases/download/places/wa_building_centroids.geojson\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Display the vector dataset on the map."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "m = leafmap.Map()\n",
    "m.add_raster(image, layer_name=\"Image\")\n",
    "m.add_circle_markers_from_xy(\n",
    "    geojson, radius=3, color=\"red\", fill_color=\"yellow\", fill_opacity=0.8\n",
    ")\n",
    "m"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![image](https://github.com/user-attachments/assets/f0d3ff1e-15fa-4bd3-ac15-637e8d63527d)"
   ]
  },
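  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Segment the image with a vector dataset\n",
    "\n",
    "Pass the file path or URL of the vector dataset to `predict_by_points()` to segment the image. Setting `multimask_output=False` requests a single mask per point prompt."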
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "output_masks = \"building_masks.tif\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sam.predict_by_points(\n",
    "    point_coords_batch=geojson,\n",
    "    point_crs=\"EPSG:4326\",\n",
    "    output=output_masks,\n",
    "    dtype=\"uint8\",\n",
    "    multimask_output=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Display the segmented masks on the map."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "m.add_raster(\n",
    "    output_masks, cmap=\"jet\", nodata=0, opacity=0.7, layer_name=\"Building masks\"\n",
    ")\n",
    "m"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![image](https://github.com/user-attachments/assets/262e1a31-1648-47d2-9e71-c85ab15b1a5c)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Clean up the result\n",
    "\n",
    "Remove small objects from the segmented masks, fill holes, and compute geometric properties."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "out_vector = \"building_vector.geojson\"\n",
    "out_image = \"buildings.tif\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "array, gdf = sam.region_groups(\n",
    "    output_masks, min_size=200, out_vector=out_vector, out_image=out_image\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gdf.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![image](https://github.com/user-attachments/assets/af9ffa11-8ebe-4b42-8cba-3f5bcc4912f4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Regularize building footprints\n",
    "\n",
    "Regularize the building footprints using the `regularize()` function imported from `samgeo.common`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "output_regularized = \"building_regularized.geojson\"\n",
    "regularize(out_vector, output_regularized)"
   ]
  },
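  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see what regularization changed, the two vector files can be compared. The following cell is a minimal sketch and not part of the original workflow: it assumes `geopandas` is available (it is normally installed alongside `segment-geospatial`) and that both files contain polygon geometries. `vertex_count` is a hypothetical helper defined only for this illustration. Regularized footprints typically have fewer vertices, but the exact numbers depend on the data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import geopandas as gpd\n",
    "\n",
    "\n",
    "def vertex_count(geom):\n",
    "    \"\"\"Count exterior-ring vertices of a Polygon or MultiPolygon (hypothetical helper).\"\"\"\n",
    "    if geom is None or geom.is_empty:\n",
    "        return 0\n",
    "    if geom.geom_type == \"Polygon\":\n",
    "        return len(geom.exterior.coords)\n",
    "    if geom.geom_type == \"MultiPolygon\":\n",
    "        return sum(len(part.exterior.coords) for part in geom.geoms)\n",
    "    return 0\n",
    "\n",
    "\n",
    "if os.path.exists(out_vector) and os.path.exists(output_regularized):\n",
    "    raw = gpd.read_file(out_vector)\n",
    "    reg = gpd.read_file(output_regularized)\n",
    "    print(f\"Footprints before regularization: {len(raw)}\")\n",
    "    print(f\"Footprints after regularization: {len(reg)}\")\n",
    "    # Total exterior vertices before and after regularization\n",
    "    print(f\"Vertices before: {raw.geometry.apply(vertex_count).sum()}\")\n",
    "    print(f\"Vertices after: {reg.geometry.apply(vertex_count).sum()}\")\n",
    "else:\n",
    "    print(\"Vector files not found; run the previous cells first\")"
   ]
  },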
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Display the regularized building footprints on the map."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "m = leafmap.Map()\n",
    "m.add_raster(image, layer_name=\"Image\")\n",
    "style = {\n",
    "    \"color\": \"#ffff00\",\n",
    "    \"weight\": 2,\n",
    "    \"fillColor\": \"#7c4185\",\n",
    "    \"fillOpacity\": 0,\n",
    "}\n",
    "m.add_raster(out_image, cmap=\"tab20\", opacity=0.7, nodata=0, layer_name=\"Buildings\")\n",
    "m.add_vector(\n",
    "    output_regularized, style=style, layer_name=\"Building regularized\", info_mode=None\n",
    ")\n",
    "m"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![image](https://github.com/user-attachments/assets/b39ee029-2089-45b8-8ac0-ba0d750cec22)"
   ]
  },
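  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Interactive segmentation\n",
    "\n",
    "Call `show_map()` to open an interactive map for adding point prompts and segmenting the image."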
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sam.show_map()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![image](https://github.com/user-attachments/assets/4f487505-6e89-4892-9a70-95ab0aa69cb6)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
